//
// Copyright © 2018-2021 Arm Limited.
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "../include/ethosn_support_library/Support.hpp"
#include "BufferManager.hpp"
#include "Compiler.hpp"
#include "Graph.hpp"
#include "StrategyConfig.hpp"
#include "WeightEncoder.hpp"

#include <ethosn_command_stream/CommandData.hpp>
#include <ethosn_command_stream/CommandStreamBuffer.hpp>

#include <cmath>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <set>
#include <unordered_map>
#include <utility>
#include <vector>

namespace ethosn
{
namespace support_library
{

// Forward declarations of support-library types.
// NOTE(review): Pass is defined further down in this header and Compiler.hpp is already included above,
// so the declarations of Pass and Compiler look redundant; kept for compatibility.
class CompiledNetwork;
class IStrategy;
class Pass;
class McePlePass;
class SoftmaxPass;
class Compiler;

// Structure holding SRAM offsets for tensors: input, output, weights and PLE code.
// Every field defaults to 0 so that a default-constructed instance (e.g. `SramOffsets offsets;`)
// never carries indeterminate values. Aggregate initialization such as `SramOffsets{ in, out, w, ple }`
// still works (default member initializers keep the type an aggregate from C++14 onwards).
struct SramOffsets
{
    uint32_t inputOffset   = 0;
    uint32_t outputOffset  = 0;
    uint32_t weightOffset  = 0;
    uint32_t pleCodeOffset = 0;
};

/// Converts a support-library BufferLocation into the command_stream::DataLocation value corresponding to it.
/// NOTE(review): the mapping itself is defined elsewhere; confirm behavior for unmapped locations there.
command_stream::DataLocation GetCommandDataLocation(BufferLocation bufferLocation);

/// Abstract base class for a pass: a group of graph nodes (see GetNodes()) that is emitted into the
/// command stream via PreGenerate() / Generate() / PostGenerate() and whose performance can be estimated.
/// Derived classes must implement Generate() and GetStats().
class Pass
{
public:
    /// @param capabilities  Hardware capabilities; stored by reference, so it must outlive this Pass.
    /// @param id            Identifier later returned by GetId().
    Pass(const HardwareCapabilities& capabilities, size_t id)
        : m_Id(id)
        , m_Capabilities(capabilities)
    {}

    virtual ~Pass() = default;

    /// @return The identifier this pass was constructed with.
    size_t GetId() const { return m_Id; }

    /// @return The generation flag (false on construction).
    bool IsGenerated() const { return m_IsGenerated; }

    /// @return The estimation flag (false on construction).
    bool IsEstimated() const { return m_IsEstimated; }

    /// @return The nodes that make up this pass.
    const std::vector<Node*>& GetNodes() const { return m_Nodes; }

    /// Attaches this pass to @p section. The pointer is not owned by the Pass.
    void SetSection(Section* section) { m_Section = section; }

    /// @return The section set via SetSection(), or nullptr if none has been set.
    Section* GetSection() const { return m_Section; }

    /// Generates this Pass by adding appropriate entries to the given command stream, memory map and buffer table.
    virtual void
        Generate(command_stream::CommandStreamBuffer& cmdStream, BufferManager& bufferManager, bool dumpRam) = 0;

    /// Estimates the performance of this Pass, appending results to @p perfStream.
    void Estimate(std::vector<PassPerformanceData>& perfStream, const EstimationOptions& estimationOptions);

    /// Emits the section command for this pass into the given command stream.
    void PreGenerate(command_stream::CommandStreamBuffer& cmdStream);

    /// Emits a dump command into the given command stream when one is required.
    void PostGenerate(command_stream::CommandStreamBuffer& cmdStream, bool dumpRam);

    /// @return Attributes used when drawing this pass in dot (graph) output.
    virtual DotAttributes GetDotAttributes();

protected:
    std::set<uint32_t> GetCorrespondingOperationIds() const;

    /// Performance estimation hooks
    /// @{
    virtual PassStats GetStats(const EstimationOptions& estimationOptions) = 0;
    /// @}

    // Member order is significant: it fixes both object layout and initialization order.
    size_t m_Id;
    bool m_IsGenerated = false;
    bool m_IsEstimated = false;
    const HardwareCapabilities& m_Capabilities;
    std::vector<Node*> m_Nodes;
    Section* m_Section = nullptr;

    /// Indices of the first and last command-stream commands emitted for this pass.
    /// Filled in by PreGenerate / PostGenerate and written to the dot files to aid debugging.
    /// @{
    uint32_t m_CommandStreamFirstCommandIdx = 0;
    uint32_t m_CommandStreamLastCommandIdx  = 0;
    /// @}
};

/// Returns the single consumer of @p source's output if it can be merged into the same pass as @p source.
///
/// @tparam TNode  The node type the caller wants to merge next.
/// @param source  Node whose output is inspected; must not be null.
/// @return The destination of source's only output edge, cast to TNode, or nullptr when:
///         - source does not have exactly one output edge,
///         - that destination is not a TNode, or
///         - that destination already has a pass (GetPass() is non-null).
template <typename TNode>
TNode* GetNextLinearNodeForInclusionInPass(Node* source)
{
    // Check that our last operation's output is not used by anything else, so that we can merge more operations
    // without affecting any other subsequent operations.
    if (source->GetOutputs().size() != 1)
    {
        // Our output is used by other operations too, so cannot fuse without breaking them.
        return nullptr;
    }

    // dynamic_cast yields nullptr when the consumer is not a TNode; that must be rejected before dereferencing.
    TNode* next = dynamic_cast<TNode*>(source->GetOutput(0)->GetDestination());
    if (next == nullptr)
    {
        return nullptr;
    }

    // A node that already belongs to a pass cannot be claimed by another one.
    if (next->GetPass() != nullptr)
    {
        return nullptr;
    }

    return next;
}

/// Looks up the ConcatNode associated with @p node.
/// NOTE(review): definition not visible here — presumably returns nullptr when there is no such ConcatNode;
/// confirm the search rule and null contract against the implementation.
ConcatNode* FindConcatNode(Node* node);

/// Calculates the supertensor information for @p inputToConcat as an input of @p concatNode.
/// NOTE(review): the meaning of the two returned TensorShapes (e.g. offset within, and shape of, the
/// supertensor) is not visible here — confirm against the implementation before relying on it.
std::pair<TensorShape, TensorShape> CalculateConcatSupertensorInfo(Node* inputToConcat, ConcatNode* concatNode);

}    // namespace support_library
}    // namespace ethosn
